#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <tuple>
#include <sys/cdefs.h>
#include "linalg.hpp"
#include "esmd_types.h"
#include "rigid.hpp"
// Note: the LINCS expansion order (formerly a global `int nrec = 4;`) is now the member lincs::nrec.

#include <string.h>

struct lincs {
  int nrec = 4;
  template <typename... Ts>
  static __always_inline void eval_all(Ts... args) {
  }
  template <class T, int... IAs, int... Is, int... Js, int... ICs, int... ICMs, int... JCMs>
  __always_inline void do_lincs_static(vec<real> *ox, vec<real> *oxp, real *oinvmass, rigid_rec &rec,
                                              int *idx,
                                              utils::iseq<int, IAs...> iaseq,
                                              utils::iseq<int, Is...> iseq,
                                              utils::iseq<int, Js...> jseq,
                                              utils::iseq<int, ICs...> icseq,
                                              utils::iseq<int, ICMs...> icmatseq,
                                              utils::iseq<int, JCMs...> jcmatseq) {
    vec<real> x[] = {ox[idx[IAs]]...};
    vec<real> xp[] = {oxp[idx[IAs]]...};
    real invmass[] = {oinvmass[idx[IAs]]...};
    vec<real> B[] = {(x[Is] - x[Js]).unit_vec()...};
    real Sdiag[] = {1.0 / sqrt(invmass[Is] + invmass[Js])...};
    real r0[] = {rec.r0[ICs]...};
    real rhs[] = {Sdiag[ICs] * (B[ICs].dot(xp[Is] - xp[Js]) - r0[ICs])...};
    real sol[] = {rhs[ICs]...};
    real A[T::ncons * T::ncons] = {
        ICMs == JCMs ? 0 : T::sign(ICMs, JCMs) * invmass[T::conn(ICMs, JCMs)] * Sdiag[ICMs] * Sdiag[JCMs] * B[ICMs].dot(B[JCMs])...};
    for (int ii = 0; ii < nrec; ii++) {
      real rhsnew[T::ncons] = {tdot<T::ncons>(A + ICs * T::ncons, rhs)...};
      eval_all(rhs[ICs] = rhsnew[ICs]...);
      eval_all(sol[ICs] += rhsnew[ICs]...);
    }

    eval_all(xp[Is] -= B[ICs] * invmass[Is] * Sdiag[ICs] * sol[ICs]...);
    eval_all(xp[Js] += B[ICs] * invmass[Js] * Sdiag[ICs] * sol[ICs]...);

    real rhscorr[] = {Sdiag[ICs] * (r0[ICs] - sqrt(2.0 * r0[ICs] * r0[ICs] - (xp[Is] - xp[Js]).norm2()))...};
    real solcorr[] = {rhscorr[ICs]...};
    for (int ii = 0; ii < nrec; ii++) {
      real rhsnew[T::ncons] = {tdot<T::ncons>(A + ICs * T::ncons, rhscorr)...};
      eval_all(rhscorr[ICs] = rhsnew[ICs]...);
      eval_all(solcorr[ICs] += rhsnew[ICs]...);
    }

    eval_all(xp[Is] -= B[ICs] * invmass[Is] * Sdiag[ICs] * solcorr[ICs]...);
    eval_all(xp[Js] += B[ICs] * invmass[Js] * Sdiag[ICs] * solcorr[ICs]...);
    eval_all(oxp[idx[IAs]] = xp[IAs]...);
  }

  template <class T, int... IAs, int... Is, int... Js, int... ICs, int... ICMs, int... JCMs>
  __always_inline void do_lincs(vec<real> *f, real invdtfsq,
                                       vec<real> *ox, vec<real> *oxp, real *oinvmass,
                                       rigid_rec &rec,
                                       int *idx,
                                       utils::iseq<int, IAs...> iaseq,
                                       utils::iseq<int, Is...> iseq,
                                       utils::iseq<int, Js...> jseq,
                                       utils::iseq<int, ICs...> icseq,
                                       utils::iseq<int, ICMs...> icmatseq,
                                       utils::iseq<int, JCMs...> jcmatseq) {
    vec<real> x[] = {ox[idx[IAs]]...};
    vec<real> xp[] = {oxp[idx[IAs]]...};
    real invmass[] = {oinvmass[idx[IAs]]...};
    vec<real> B[] = {(x[Is] - x[Js]).unit_vec()...};
    real Sdiag[] = {1.0 / sqrt(invmass[Is] + invmass[Js])...};
    real r0[] = {rec.r0[ICs]...};
    real rhs[] = {Sdiag[ICs] * (B[ICs].dot(xp[Is] - xp[Js]) - r0[ICs])...};
    real sol[] = {rhs[ICs]...};
    real A[T::ncons * T::ncons] = {
        ICMs == JCMs ? 0 : T::sign(ICMs, JCMs) * invmass[T::conn(ICMs, JCMs)] * Sdiag[ICMs] * Sdiag[JCMs] * B[ICMs].dot(B[JCMs])...};
    for (int ii = 0; ii < nrec; ii++) {
      real rhsnew[T::ncons] = {tdot<T::ncons>(A + ICs * T::ncons, rhs)...};
      eval_all(rhs[ICs] = rhsnew[ICs]...);
      eval_all(sol[ICs] += rhsnew[ICs]...);
    }

    eval_all(xp[Is] -= B[ICs] * invmass[Is] * Sdiag[ICs] * sol[ICs]...);
    eval_all(xp[Js] += B[ICs] * invmass[Js] * Sdiag[ICs] * sol[ICs]...);

    real rhscorr[] = {Sdiag[ICs] * (r0[ICs] - sqrt(2.0 * r0[ICs] * r0[ICs] - (xp[Is] - xp[Js]).norm2()))...};
    real solcorr[] = {rhscorr[ICs]...};
    for (int ii = 0; ii < nrec; ii++) {
      real rhsnew[T::ncons] = {tdot<T::ncons>(A + ICs * T::ncons, rhscorr)...};
      eval_all(rhscorr[ICs] = rhsnew[ICs]...);
      eval_all(solcorr[ICs] += rhsnew[ICs]...);
    }

    eval_all(f[idx[Is]] -= B[ICs] * Sdiag[ICs] * (sol[ICs] + solcorr[ICs]) * invdtfsq...);
    eval_all(f[idx[Js]] += B[ICs] * Sdiag[ICs] * (sol[ICs] + solcorr[ICs]) * invdtfsq...);
  }
  template <class T>
  void run(vec<real> *f, real rdtfsq, vec<real> *x, vec<real> *xp, real *invmass, rigid_rec &rec, int *idx) {
    do_lincs<T>(f, rdtfsq, x, xp, invmass, rec, idx, T::iaseq, T::iseq, T::jseq, T::icseq, T::icmatseq, T::jcmatseq);
  }

  template <class T>
  void run_static(vec<real> *x, vec<real> *xp, real *invmass, rigid_rec &rec, int *idx) {
    do_lincs_static<T>(x, xp, invmass, rec, idx, T::iaseq, T::iseq, T::jseq, T::icseq, T::icmatseq, T::jcmatseq);
  }
};

// Declarations only; the definitions are not in this section of the file.
// NOTE(review): judging by the names, lincs_setup presumably prepares the
// LINCS constraint data for `grid`, and lincs_post_force applies the
// constraints after force evaluation. The meaning and units of `dt` and
// `ftm2v` cannot be determined from here; confirm against the definitions.
void lincs_setup(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v);
void lincs_post_force(cellgrid_t *grid, mpp_t *mpp, real dt, real ftm2v);
